# @File    : common.py

import datetime
from abc import abstractmethod

from PyQt5 import QtCore
from PyQt5.QtCore import QTimer, QObject

import wraper
from ezmath import *
from wraper import *

BIG_NUM = 10 ** 8  # a "very large" number used as a sentinel value
eps = 1e-7  # two floats whose difference is below this are treated as equal
QT_GUI_PATH = os.path.dirname(__file__)
TEMP_DIR = os.path.join(QT_GUI_PATH, "temp")
# Scratch directory created at import time; exist_ok=True makes a repeated
# call harmless if the directory is already there.
os.makedirs(TEMP_DIR, exist_ok=True)


class Purpose(enum.Enum):
    """Purpose of an acquisition: recording the background, or measuring."""

    RECORD_BACKGROUND = 0  # record the background level
    MEASUREMENT = 1  # regular measurement


class Method(enum.Enum):
    """Choice between a user-defined and the default method."""

    USER_DEFINE = 0  # user-defined
    DEFAULT = 1  # built-in default


def jog_to(hm: libhallmachine.HallMachine, axis: Axis, target: float) -> bool:
    """
    Jog ``axis`` of ``hm`` to ``target``.

    Delegates to ``HallMachineAxisHelper(hm).d_jog_to[axis]`` and returns
    whatever that callable returns (annotated as bool).
    """
    jog_axis = HallMachineAxisHelper(hm).d_jog_to[axis]
    return jog_axis(target)


def stop(hm: libhallmachine.HallMachine, axis: Axis) -> bool:
    """
    Stop motion of ``axis`` on ``hm``.

    Delegates to ``HallMachineAxisHelper(hm).d_stop[axis]`` and returns
    whatever that callable returns (annotated as bool).
    """
    stop_axis = HallMachineAxisHelper(hm).d_stop[axis]
    return stop_axis()


def return_to_0(hm: libhallmachine.HallMachine, axis: Axis) -> bool:
    """
    Send ``axis`` of ``hm`` back to its zero position.

    Delegates to ``HallMachineAxisHelper(hm).d_find_0[axis]`` and returns
    whatever that callable returns (annotated as bool).
    """
    find_zero = HallMachineAxisHelper(hm).d_find_0[axis]
    return find_zero()


class Trajectory:
    """
    Defines the high-level trajectory data format (one row per control point).

    It lives in ``common`` because it is used not only by the advanced UI but
    also by the control-point calculator.
    """
    columns = "X/mm,Y/mm,Z/mm,A/degree,C/degree,I/A,records".split(",")
    kX_mm, kY_mm, kZ_mm, kA_deg, kC_deg, kI_A, kRecorders = columns
    columns_xyz = [kX_mm, kY_mm, kZ_mm]
    # Position column -> machine axis it describes.
    d_columns_axis = {
        kX_mm: Axis.X,
        kY_mm: Axis.Y,
        kZ_mm: Axis.Z,
        kA_deg: Axis.A,
        kC_deg: Axis.C
    }

    def __init__(self, ):
        self.df = pandas.DataFrame(columns=self.columns)

    def append(self, df: pandas.DataFrame):
        """
        Append the rows of ``df`` (converted to float) to this trajectory.

        :param df: rows to append; its columns must match ``columns``
        :return: ``self``, to allow chaining
        :raises HeaderNotMatchException: if the columns of ``df`` do not match
        """
        if not wraper.columns_is_match(df, self.columns):
            raise HeaderNotMatchException(self.columns)
        df = df.astype(float)
        self.df = pandas.concat([self.df, df], ignore_index=True)
        return self

    def fill_empty_or_na(self, hmhelper: HallMachineAxisHelper):
        """
        Fill missing cells in place.

        - A missing position (X/Y/Z/A/C) copies the value from the previous
          row. In the first row it is taken from ``hmhelper.get_pos(axis)``.
        - A missing ``records`` becomes 1.
        - The current column ``I/A`` is set to 0 for every row.

        :param hmhelper: source of the machine's current position for the first row
        """
        # A tuple keeps the column order deterministic (a set literal does not).
        position_columns = (self.kX_mm, self.kY_mm, self.kZ_mm, self.kA_deg, self.kC_deg)
        previous = None  # label of the previously processed row
        for label in self.df.index:
            for col in position_columns:
                if not pandas.isna(self.df.at[label, col]):
                    continue
                if previous is None:
                    value = hmhelper.get_pos(self.d_columns_axis[col])
                else:
                    value = self.df.at[previous, col]
                # Write through .at: chained `df[col][row] = v` may update a
                # temporary copy and be silently lost under Copy-on-Write.
                self.df.at[label, col] = value
            if pandas.isna(self.df.at[label, self.kRecorders]):
                self.df.at[label, self.kRecorders] = 1
            previous = label
        # TODO: set the real current; for now every row gets 0.
        if len(self.df):
            self.df[self.kI_A] = 0

    def append_n_rows_with_nan(self, n):
        """Append ``n`` all-NaN rows (delegates to the module-level helper)."""
        self.df = append_n_rows_with_nan(self.df, n)

    def __getitem__(self, index: int):
        """Return the row with index label ``index`` (``DataFrame.loc``)."""
        return self.df.loc[index]

    def __len__(self):
        return len(self.df)


class ControlPoint:
    """
    One record (row) of the control-point table.

    Uses the same column layout as ``Trajectory``. The values are held in
    ``self.series``, a ``pandas.Series`` indexed by ``columns``.
    """
    columns = Trajectory.columns
    kX_mm, kY_mm, kZ_mm, kA_deg, kC_deg, kI_A, kRecorders = columns
    columns_xyz = Trajectory.columns_xyz
    columns_all_except_records = columns[:-1]  # every column except "records"

    def __init__(self, x_mm=0, y_mm=0, z_mm=0, A_deg=0, C_deg=0, I_A=0, records=0):
        """
        Build a control point from explicit values; a falsy ``records`` is stored as 0.
        """
        row_values = [x_mm, y_mm, z_mm, A_deg, C_deg, I_A, records if records else 0]
        self.series = pandas.Series(row_values, index=ControlPoint.columns)

    def __getitem__(self, item):
        """Index the underlying series by a column name or a list of names."""
        return self.series[item]

    @staticmethod
    def from_hm(hm: HallMachine):
        """Snapshot the UMAC position of ``hm``; current and records are 0."""
        pos = UMAC_Pos.from_UMACData(hm.umac_data_)
        return ControlPoint(x_mm=pos.x, y_mm=pos.y, z_mm=pos.z,
                            A_deg=pos.a, C_deg=pos.c, I_A=0, records=0)

    @staticmethod
    def from_Trajectory(traj: Trajectory, i):
        """Wrap the row of ``traj.df`` with index label ``i`` as a control point."""
        point = ControlPoint()
        point.series = traj.df.loc[i]
        return point

    def add_into_traj(self, traj: Trajectory) -> Trajectory:
        """Append this point as a new row of ``traj.df`` and return ``traj``."""
        new_row = pandas.DataFrame([self.series])
        traj.df = pandas.concat([traj.df, new_row], ignore_index=True)
        return traj

    def __sub__(self, other):
        """Column-wise difference ``self - other`` as a new ControlPoint."""
        difference = ControlPoint()
        difference.series = self.series - other.series
        return difference

    def distance(self, othor):
        """Euclidean norm of the X/Y/Z part of ``self - othor``."""
        xyz_delta = (self - othor)[self.columns_xyz]
        return numpy.linalg.norm(xyz_delta)


class RecorderSupportPushHM(RecorderBase):
    """
    Recorder that supports ``push_back(hm)``: appending a sample taken
    directly from a HallMachine.
    """

    def __init__(self, autogen_path_method):
        super(RecorderSupportPushHM, self).__init__(autogen_path_method=autogen_path_method)

    def push_back(self, hm: HallMachine, tag=""):
        """
        Append one row built from the current state of ``hm``.

        The row holds:
        - the current ``time.time()``;
        - the UMAC x/y/z position and a/c angles;
        - the components of ``hm.B_`` and ``hm.V_`` plus their Euclidean norms;
        - current 0 and ``tag``.

        The cumulative-path column ``k_s_mm`` is then filled. The first row
        gets 0; every later row gets the previous value plus
        ``self.distance(last, last - 1)``.

        :param hm: machine to sample
        :param tag: label stored in the ``k_tag`` column
        """
        umac_data: libhallmachine.UMACData = hm.umac_data_
        B: Vec3 = hm.B_
        V: libhallmachine.Vec3d = hm.V_
        self.df = pandas.concat(
            [self.df,
             pandas.DataFrame(
                 {self.k_timestamp: [time.time()],
                  self.k_x_mm: [umac_data.x_pos_mm],
                  self.k_y_mm: [umac_data.y_pos_mm],
                  self.k_z_mm: [umac_data.z_pos_mm],
                  self.k_s_mm: [numpy.nan],
                  self.k_A_deg: [umac_data.a_ang_deg],
                  self.k_C_deg: [umac_data.c_ang_deg],
                  self.k_Bx_Gauss: [B.x],
                  self.k_By_Gauss: [B.y],
                  self.k_Bz_Gauss: [B.z],
                  self.k_B_Gauss: [numpy.linalg.norm(vec3d_to_numpy(B))],
                  self.k_I_A: [0],
                  self.k_Vx_V: [V.x], self.k_Vy_V: [V.y], self.k_Vz_V: [V.z],
                  self.k_V_V: [numpy.linalg.norm(vec3d_to_numpy(V))],
                  self.k_tag: [tag]
                  })], ignore_index=True
        )
        # ignore_index=True above guarantees labels 0..n-1, so labels == positions.
        last = self.df.shape[0] - 1
        if last == 0:
            s_mm = 0
        else:
            s_mm = self.df.at[last - 1, self.k_s_mm] + self.distance(last, last - 1)
        # Write through .at: chained `df[col][row] = v` may update a temporary
        # copy and be silently lost under pandas Copy-on-Write.
        self.df.at[last, self.k_s_mm] = s_mm


class RecorderForBackground(RecorderSupportPushHM):
    """
    Recorder whose path callback returns ``wraper.DefaultFilePath.raw_VBackground``.
    """

    def __init__(self):
        def background_path():
            return wraper.DefaultFilePath.raw_VBackground

        super().__init__(background_path)


def get_df_and_check(csvpath: str, columns: list) -> pandas.DataFrame:
    """
    Load ``csvpath`` if it exists and its header matches ``columns``.

    If the file does not exist, is empty (no header at all), or its columns do
    not match, a new empty DataFrame with exactly ``columns`` is returned.

    :param csvpath: path of the CSV file to read
    :param columns: required column names
    :return: the loaded DataFrame, or an empty DataFrame with ``columns``
    """
    if not os.path.exists(csvpath):
        return pandas.DataFrame(columns=columns)
    try:
        df = pandas.read_csv(csvpath)
    except pandas.errors.EmptyDataError:
        # A zero-byte / header-less file has no columns to match; treat it
        # like a missing file instead of crashing.
        return pandas.DataFrame(columns=columns)
    if columns_is_match(df, columns):
        return df
    return pandas.DataFrame(columns=columns)


class RecorderForVtoB_One_Probe(RecorderSupportPushHM):
    """
    V-to-B calibration recorder for a single probe (one axis).

    Each pushed row stores a ground-truth B value, the probe's voltage for
    ``axis`` and a tag derived from B.
    """
    # Columns of the shared table actually used by each probe: (B, V, tag).
    d_used_columns = {
        Axis.X: [RecorderBase.k_Bx_Gauss, RecorderBase.k_Vx_V, RecorderBase.k_tag],
        Axis.Y: [RecorderBase.k_By_Gauss, RecorderBase.k_Vy_V, RecorderBase.k_tag],
        Axis.Z: [RecorderBase.k_Bz_Gauss, RecorderBase.k_Vz_V, RecorderBase.k_tag],
    }
    # ready_to_write() requires strictly more distinct tags than this.
    MIN_VALUE_COUNT_OF_B = 4

    def __init__(self, axis: Axis):
        super(RecorderForVtoB_One_Probe, self).__init__(lambda: wraper.DefaultFilePath.raw_V_to_B_componet(axis))
        self.axis = axis
        self.read_csv()
        self.used_columns = self.d_used_columns[axis]

    def read_csv(self, path: str = None):
        """
        Reload ``self.df`` from ``path`` (default: the auto-generated path).

        A missing, empty or header-mismatching file yields an empty table.
        """
        if not path:
            path = self.autogen_path_method()
        self.df = get_df_and_check(path, self.columns)

    def get_axis(self):
        return self.axis

    def get_update_time_str(self) -> str:
        """Modification time of the raw file as "%Y-%m-%d %H:%M", or "Unknown" if absent."""
        path = self.autogen_path_method()
        return datetime.datetime.strftime(datetime.datetime.fromtimestamp(os.path.getmtime(path)),
                                          "%Y-%m-%d %H:%M") if os.path.exists(path) else "Unknown"

    def ready_to_write(self):
        """
        Return False while too few calibration points have been recorded.

        :return: True when the number of distinct, non-missing tags exceeds
            ``MIN_VALUE_COUNT_OF_B``
        """
        return self.df[self.k_tag].nunique() > self.MIN_VALUE_COUNT_OF_B

    def to_csv(self):
        """
        Update two files: this probe's raw V-to-B file and the shared mean V-to-B file.

        For every distinct, non-missing tag, the first B value and the mean V
        of that tag are written into this probe's columns of the mean file,
        one row per tag, sorted by ascending V. The other probes' columns
        are kept as read. Nothing is written if ``ready_to_write()`` is False.
        """
        B_column, V_column, tag_column = self.used_columns
        if not self.ready_to_write():
            print("暂不支持写入，因为数据点过少")
            return
        mean_XYZ_path = DefaultFilePath.mean_VtoB
        mean_XYZ_df = get_df_and_check(mean_XYZ_path, self.columns)
        print(mean_XYZ_df)
        # Reset this probe's columns; they are refilled below, one row per tag.
        # NOTE(review): k_tag is part of every probe's used_columns, so this also
        # blanks the shared tag column, which is never refilled. Confirm intended.
        mean_XYZ_df[self.used_columns] = numpy.nan
        tags = self.df[tag_column].dropna().unique()  # deterministic first-seen order
        for i, tag in enumerate(tags):
            rows = self.df[self.df[tag_column] == tag]
            # .loc (not .at) so rows beyond the current length are appended.
            mean_XYZ_df.loc[i, B_column] = rows[B_column].iloc[0]
            mean_XYZ_df.loc[i, V_column] = rows[V_column].astype(float).mean()
        # Sort this probe's rows by V with missing values last. Series.argsort
        # marks NaN with -1 (deprecated), which makes .iloc pick wrong rows
        # whenever a NaN sits among the filled rows.
        ordered = mean_XYZ_df[self.used_columns].sort_values(
            V_column, na_position="last", kind="mergesort")
        mean_XYZ_df[self.used_columns] = ordered.values
        mean_XYZ_recorder = RecorderBase(lambda: DefaultFilePath.mean_VtoB)
        mean_XYZ_recorder.df = mean_XYZ_df
        mean_XYZ_recorder.to_csv()
        super(RecorderForVtoB_One_Probe, self).to_csv()

    def get_brief(self) -> pandas.DataFrame:
        """
        :return: only the columns this probe actually uses
        """
        return self.df[self.used_columns]

    def push_back(self, hm: HallMachine, B_ground_truth: float):
        """
        Append one calibration point: ``B_ground_truth``, the probe's relative V
        for this axis, and the tag "B:%.4f" % B_ground_truth.
        """
        tag = "B:%.4f" % B_ground_truth
        hm_helper = HallMachineAxisHelper(hm)
        v = [B_ground_truth, hm_helper.d_get_V_relative[self.axis](), tag]
        # .loc supports a list of columns and enlarges with a new row label;
        # .at only accepted this through an internal fallback.
        self.df.loc[self.df.shape[0], self.used_columns] = v


class Recorder(RecorderSupportPushHM):
    """
    Defines the standard format and read/write handling of measurement records.
    """
    # Join component-wise so the OS separator is used. The literal
    # r"qt_gui\record" is one oddly named directory on POSIX systems.
    DEFAULT_RECORD_DIR = os.path.join(wraper.MAGLAB_DIR, "qt_gui", "record")
    DEFAULT_PREFIX = ""

    def __init__(self, record_dir=DEFAULT_RECORD_DIR, prefix=DEFAULT_PREFIX, ):
        """
        :param record_dir: directory the record CSV is placed in
        :param prefix: file-name prefix. The file name is ``prefix`` + the
            creation time formatted as "_%Y%m%d_%H%M%S" + ".csv".
        """
        print("创建了一个recorder，", record_dir, prefix)
        # NOTE(review): presumably replaced by the base initializer — confirm.
        self.df = pandas.DataFrame(columns=self.columns)
        self.record_dir = record_dir
        self.prefix = prefix
        create_time = datetime.datetime.now()
        # The lambda reads record_dir/prefix at call time; the timestamp is fixed now.
        super(Recorder, self).__init__(lambda: os.path.join(
            self.record_dir,
            self.prefix + datetime.datetime.strftime(create_time, "_%Y%m%d_%H%M%S") + ".csv"
        ))


# Per-axis maximum speed values (units not shown here; presumably those used
# by set_axis_speed — confirm).
MAX_SPEED = {Axis.Z: 40 * 10}
# TODO(Zijing): check whether these values are reasonable
MAX_SPEED.update({axis: 10 * 10 for axis in (Axis.X, Axis.Y, Axis.A, Axis.C)})
# Per-axis *_MAX_MM limits taken from libhallmachine (linear axes only).
BOUNDINGS = dict(zip(
    (Axis.X, Axis.Y, Axis.Z),
    (libhallmachine.X_MAX_MM, libhallmachine.Y_MAX_MM, libhallmachine.Z_MAX_MM),
))


def set_axis_speed(hm: libhallmachine.HallMachine, axis: Axis, speed: int):
    """
    Set the speed of ``axis`` on ``hm``.

    Delegates to ``HallMachineAxisHelper(hm).d_set_speed[axis]`` and returns
    whatever that callable returns.
    """
    speed_setter = HallMachineAxisHelper(hm).d_set_speed[axis]
    return speed_setter(speed)


class AbstractMeasurerTask(QObject):
    """
    Abstract measurement task.

    NOTE: the metaclass here comes from QObject, not ``abc.ABCMeta``, so the
    ``@abstractmethod`` markers below are documentation only. Python does
    not enforce them, and a subclass that does not override ``start``/``stop``
    inherits the no-op bodies.
    """
    # Separate thresholds because different controllers may have different precision.
    EPS_DISTANCE = 1e-5  # negligible change in distance
    EPS_ANG_A = 1  # negligible change in angle (A axis)
    EPS_ANG_C = 8  # negligible change in angle (C axis)
    EPS_I = eps  # negligible change in current
    sig_mean_data_updated = QtCore.pyqtSignal(RecorderBase)
    sig_raw_data_updated = QtCore.pyqtSignal(RecorderBase)
    sig_mean_data_done = QtCore.pyqtSignal(RecorderBase)
    sig_raw_data_done = QtCore.pyqtSignal(RecorderBase)

    def __init__(self):
        super(AbstractMeasurerTask, self).__init__()

    @abstractmethod
    def start(self):
        """
        Start the measurement.
        :return:
        """

    @abstractmethod
    def stop(self):
        """
        Stop the measurement.
        :return:
        """


class DataExchangerTimer(QTimer):
    """
    A QTimer that doubles as a data-exchange intermediary through its signals.
    """
    sig_setup_canvas_on_calculator = QtCore.pyqtSignal()
    sig_insert_control_pts = QtCore.pyqtSignal(Trajectory)
    sig_plot_control_pts = QtCore.pyqtSignal(Trajectory)  # plot the trajectory

    def __init__(self, *args, **kwargs):
        """Log the creation, then forward all arguments to QTimer."""
        print("DataExchangerTimer object +1")
        super().__init__(*args, **kwargs)


# Shared timer / data-exchange hub, created when this module is imported.
# Interval is 200 ms (QTimer intervals are in milliseconds); this module does
# not start it.
DEFAULT_MAIN_TIMER = DataExchangerTimer()
DEFAULT_MAIN_TIMER.setInterval(200)
if __name__ == "__main__":
    # Ad-hoc demo of gen_arc_from_p1_p2_rho: print the generated points, then
    # plot them in 3-D and as an X/Y scatter.
    arc_pts, center = gen_arc_from_p1_p2_rho(
        numpy.array([0, 0, 0]),
        numpy.array([1, 0, 0]),
        1,
        .1,
        numpy.array([0, 0, -1])
    )
    print(arc_pts)
    import matplotlib.pyplot as plt

    xs, ys, zs = arc_pts[:, :3].T

    ax: plt.Axes = plt.axes(projection="3d")
    ax.plot3D(xs, ys, zs)
    ax.set_ylim([0, 1])
    for set_label, text in ((ax.set_xlabel, "x"), (ax.set_ylabel, "y"), (ax.set_zlabel, "z")):
        set_label(text)

    plt.figure()
    plt.plot(xs, ys, ".")
    plt.axis("equal")
    plt.show()
